# flake8: noqa
# Original example from: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/mnist.py
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Dict

import hydra
import pytorch_lightning as pl
import torch
from hydra.core.config_store import ConfigStore
from hydra.utils import instantiate
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets.mnist import MNIST

from hydra_configs.pytorch_lightning.trainer import TrainerConf
from hydra_configs.torch.optim import AdamConf

# ====== NOTE: HYDRA BLOCK =========
# Structured config definition and its registration with Hydra's ConfigStore


@dataclass
class LitClassifierConf:
    """Top-level structured config for the MNIST ``LitClassifier`` example.

    ``cli_main`` forwards every field to ``LitClassifier(**cfg)``; fields the
    model does not name explicitly are absorbed by its ``**kwargs``.
    """

    # Nested configs use default_factory: a (non-frozen) dataclass instance
    # used as a plain default is a mutable, unhashable default, which the
    # ``dataclasses`` module rejects with ValueError on Python 3.11+.
    trainer: TrainerConf = field(default_factory=TrainerConf)
    # NOTE(review): not read by LitClassifier.configure_optimizers, which
    # builds Adam from its own ``learning_rate`` argument instead.
    optim_conf: Any = field(default_factory=AdamConf)
    # Keyword arguments unpacked into every DataLoader in ``cli_main``
    # (``DataLoader(..., **cfg.dataloader)``).
    dataloader: Dict[str, Any] = field(default_factory=lambda: {"batch_size": 32})
    hidden_dim: int = 128
    data_shape: int = 1 * 28 * 28
    target_shape: int = 1 * 10
    # Dataset root; ``cli_main`` overwrites it at runtime with the launch directory.
    root_dir: str = "."
    seed: int = 1234


cs: ConfigStore = ConfigStore.instance()
# Register the dataclass under the name used by @hydra.main(config_name="litconf").
cs.store(node=LitClassifierConf, name="litconf")
# ====== / HYDRA BLOCK =========


class LitClassifier(pl.LightningModule):
    """Two-layer MLP classifier for flattened images (MNIST by default).

    Args:
        data_shape: Number of input features per sample after flattening.
        hidden_dim: Width of the hidden layer.
        target_shape: Number of output classes (length of the logit vector).
        learning_rate: Learning rate passed to ``torch.optim.Adam``.
        **kwargs: Extra entries (e.g. the rest of the Hydra config). They are
            not used by the model itself but are recorded by
            ``save_hyperparameters``.
    """

    def __init__(
        self,
        data_shape: int = 1 * 28 * 28,
        hidden_dim: int = 128,
        target_shape: int = 1 * 10,
        learning_rate: float = 1e-3,
        **kwargs  # NOTE: if you want hparams to contain/log your whole cfg, this is important
    ):
        super().__init__()
        # Record the constructor arguments in self.hparams (read later by
        # configure_optimizers for the learning rate).
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(data_shape, hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, target_shape)

    def forward(self, x):
        """Return unnormalized class logits of shape ``(batch, target_shape)``."""
        # Flatten every dimension after the batch dimension.
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        # No activation on the output layer: F.cross_entropy expects raw logits
        # (it applies log_softmax itself). A ReLU here would clamp every
        # negative logit to 0 and block its gradient.
        return self.l2(x)

    def _compute_loss(self, batch):
        """Return the cross-entropy loss for one ``(inputs, targets)`` batch."""
        x, y = batch
        y_hat = self(x)
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        return self._compute_loss(batch)

    def validation_step(self, batch, batch_idx):
        self.log("valid_loss", self._compute_loss(batch))

    def test_step(self, batch, batch_idx):
        self.log("test_loss", self._compute_loss(batch))

    def configure_optimizers(self):
        return torch.optim.Adam(lr=self.hparams.learning_rate, params=self.parameters())


@hydra.main(config_name="litconf")
def cli_main(cfg):
    """Train, validate and test ``LitClassifier`` on MNIST.

    ``cfg`` is composed by Hydra from the ``litconf`` structured config plus
    any command-line overrides, and is supplied by the ``@hydra.main``
    decorator. Every DataLoader is built with ``cfg.dataloader`` as keyword
    arguments.
    """
    # Hydra runs each job in its own output directory; point the dataset root
    # back at the launch directory so MNIST is downloaded only once.
    cfg.root_dir = hydra.utils.get_original_cwd()
    # NOTE(review): DictConfig.pretty() is deprecated in OmegaConf 2.0 and
    # removed in 2.1; on newer versions use OmegaConf.to_yaml(cfg) instead.
    print(cfg.pretty())
    pl.seed_everything(cfg.seed)

    # ---- data ----
    full_train_set, mnist_test = (
        MNIST(
            root=cfg.root_dir,
            train=is_train_split,
            download=True,
            transform=transforms.ToTensor(),
        )
        for is_train_split in (True, False)
    )
    # Split the 60,000 training images into 55,000 for fitting and 5,000 for validation.
    mnist_train, mnist_val = random_split(full_train_set, [55000, 5000])

    train_loader, val_loader, test_loader = (
        DataLoader(subset, **cfg.dataloader)
        for subset in (mnist_train, mnist_val, mnist_test)
    )

    # ---- model ----
    # LitClassifierConf is a hand-written config without a `_target_` field, so
    # the model is constructed directly. Adding `_target_` would allow
    # `hydra.utils.instantiate(...)` here as well.
    model = LitClassifier(**cfg)

    # ---- training ----
    # The trainer config is an autogenerated hydra-lightning config (it carries
    # `_target_`), so `instantiate` can build it. This also supports recursive
    # instantiation, e.g. for adding `Callback`s.
    trainer = instantiate(cfg.trainer)
    trainer.fit(model, train_loader, val_loader)

    # ---- testing ----
    trainer.test(test_dataloaders=test_loader)


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on cli_main
    # composes and passes `cfg` (the "litconf" config plus CLI overrides).
    cli_main()
